import os
import torch
import requests
from typing import List
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer

from _models.constants import *

# Hugging Face API token taken from the HF_TOKEN environment variable
# (None when the variable is not set).
hf_token = os.getenv("HF_TOKEN")

# Module-level cache for the most recently loaded model/tokenizer pair;
# both start empty and are filled in by load_model().
curr_model = None
curr_tokenizer = None

# Default torch device: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def load_model(model="BAAI/bge-small-en-v1.5", device=None):
    """Return the cached ``(model, tokenizer)`` pair for *model*, loading on demand.

    The pair lives in the module-level ``curr_model`` / ``curr_tokenizer``
    globals, so a single checkpoint is cached at a time. Requesting a name
    that differs from ``curr_model.name_or_path`` replaces the cached pair.

    Args:
        model: Model name or path handed to ``from_pretrained``.
        device: Optional device to move the model to. When ``None`` the
            model is not moved.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    global curr_model, curr_tokenizer
    if curr_model is None or curr_model.name_or_path != model:
        # Load into locals first and publish both globals together, so a
        # failure while loading the model cannot leave a new tokenizer
        # paired with the previously cached model.
        new_tokenizer = AutoTokenizer.from_pretrained(model)
        # NOTE(review): trust_remote_code=True lets the checkpoint repository
        # run its own Python code; only pass model names you trust.
        new_model = AutoModel.from_pretrained(model, trust_remote_code=True)
        new_model.eval()  # Inference only: put the model in evaluation mode
        curr_tokenizer, curr_model = new_tokenizer, new_model
        print(f"Setting tokenizer to {model}. ")
        print(f"Setting model to {model}. ")

    if device is not None:
        curr_model = curr_model.to(device)  # Move model to device only if specified
    return curr_model, curr_tokenizer


def get_huggingface_embeddings_batched(
    prompts: List[str],
    batch_size=4,
    model_name="BAAI/bge-small-en-v1.5",
    pbar=True,
    use_gpu=True,
):
    """Embed *prompts* in batches and return L2-normalised vectors.

    For each batch the prompts are tokenized (padded and truncated), run
    through the model without gradient tracking, and the vector at the first
    token position of the model's first output is taken and L2-normalised.

    Args:
        prompts: Prompts to embed.
        batch_size: Number of prompts per forward pass; must be >= 1.
        model_name: Model name or path passed to ``load_model``.
        pbar: Show a tqdm progress bar over the batches when true.
        use_gpu: Use CUDA when it is available; otherwise run on CPU.

    Returns:
        A NumPy array with one row per prompt. For empty *prompts* an array
        with zero rows is returned.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    device = torch.device("cuda" if torch.cuda.is_available() and use_gpu else "cpu")
    model, tokenizer = load_model(model_name, device)

    if len(prompts) == 0:
        # Nothing to embed: skip the forward pass instead of feeding an empty
        # batch to the tokenizer. The column count follows the model config
        # when it exposes hidden_size, else 0.
        width = getattr(model.config, "hidden_size", 0)
        return torch.empty((0, width)).numpy()

    embeddings = []

    # Stride by batch_size so every prompt is covered exactly once, including
    # a short final batch (the former "+ 0.99" rounding could drop it).
    for start in tqdm(range(0, len(prompts), batch_size), disable=not pbar):
        batch = prompts[start : start + batch_size]
        encoded_input = tokenizer(
            batch, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            model_output = model(**encoded_input)
            # First output, first token position of every sequence.
            sentence_embeddings = model_output[0][:, 0]
            sentence_embeddings = torch.nn.functional.normalize(
                sentence_embeddings, p=2, dim=1
            )
            embeddings.append(
                sentence_embeddings.cpu()
            )  # Move embeddings to CPU after computation

        # Optional: Clear the CUDA cache after processing each batch
        torch.cuda.empty_cache()

    return torch.vstack(embeddings).detach().numpy()


def get_huggingface_embedding(prompt: str, model="BAAI/bge-small-en-v1.5"):
    """Embed a single prompt and return its L2-normalised vector.

    The model is obtained through ``load_model`` and moved to the
    module-level ``device``. The vector at the first token position of the
    model's first output is normalised, and the first row is returned.

    Args:
        prompt: Text to embed.
        model: Model name or path passed to ``load_model``.

    Returns:
        The first row of the normalised embeddings as a NumPy array.
    """
    encoder, tokenizer = load_model(model)
    encoder = encoder.to(device)

    inputs = tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
    inputs = inputs.to(device)

    with torch.no_grad():
        first_output = encoder(**inputs)[0]
        first_token_vectors = first_output[:, 0]
        unit_vectors = torch.nn.functional.normalize(
            first_token_vectors, p=2, dim=1
        )
        embedding = unit_vectors[0]
    return embedding.detach().cpu().numpy()


def get_huggingface_embeddings_api(
    prompts, model="BAAI/bge-small-en-v1.5", timeout=None
):
    """Fetch embeddings for *prompts* from the hosted feature-extraction API.

    Posts the prompts to the ``feature-extraction`` pipeline endpoint for
    *model* with ``wait_for_model`` enabled, then L2-normalises the returned
    matrix along dimension 1.

    Args:
        prompts: Inputs sent as the ``"inputs"`` field of the JSON payload.
        model: Model name interpolated into the endpoint URL.
        timeout: Optional timeout forwarded to ``requests.post``. ``None``
            (the default) waits indefinitely, as before.

    Returns:
        The normalised embeddings as a NumPy array.

    Raises:
        requests.HTTPError: If the response status is not 200.
    """
    api_url = (
        f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model}"
    )
    # Only send credentials when a token is configured; otherwise the header
    # would carry the literal string "Bearer None".
    headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
    response = requests.post(
        api_url,
        headers=headers,
        json={"inputs": prompts, "options": {"wait_for_model": True}},
        timeout=timeout,
    )

    # 4xx/5xx responses raise here.
    response.raise_for_status()
    if response.status_code != 200:
        # Any other status used to fall through and return None silently.
        raise requests.HTTPError(
            f"Unexpected status {response.status_code} from {api_url}",
            response=response,
        )

    embeddings_matrix = response.json()
    # Convert the list of embeddings to a torch tensor for normalization
    embeddings_tensor = torch.tensor(embeddings_matrix)
    normalized_embeddings = torch.nn.functional.normalize(
        embeddings_tensor, p=2, dim=1
    )
    print("generated embeddings")
    return normalized_embeddings.numpy()


def get_huggingface_response(prompt, system_prompt=None, model="facebook/opt-125m"):
    """Generate text for *prompt* with a locally loaded model.

    The model comes from ``load_model`` and is moved to the module-level
    ``device``. The prompt is tokenized, passed to ``generate`` without
    gradient tracking, and the generated ids are decoded.

    Args:
        prompt: Input text handed to the tokenizer.
        system_prompt: Accepted but not used by this function.
        model: Model name or path passed to ``load_model``.

    Returns:
        The result of ``tokenizer.batch_decode`` on the generated ids, with
        special tokens skipped.
    """
    # NOTE(review): load_model builds the model with AutoModel; confirm that
    # the resulting class supports generate() for the chosen checkpoint
    # (a language-model head may be required).
    generator, tokenizer = load_model(model)
    generator = generator.to(device)

    tokenizer_kwargs = dict(padding=True, truncation=True, return_tensors="pt")
    inputs = tokenizer(prompt, **tokenizer_kwargs).to(device)

    with torch.no_grad():
        generated_ids = generator.generate(**inputs)
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)


def get_huggingface_response_batched(
    prompts,
    system_prompt=None,
    batch_size=128,
    model="facebook/opt-125m",
    pbar=True,
):
    """Generate text for every prompt in *prompts*, processing them in batches.

    The model comes from ``load_model`` and is moved to the module-level
    ``device``. Each batch is tokenized (padded and truncated), passed to
    ``generate`` without gradient tracking, and decoded.

    Args:
        prompts: Prompts to generate from.
        system_prompt: Accepted but not used by this function.
        batch_size: Number of prompts per ``generate`` call; must be >= 1.
        model: Model name or path passed to ``load_model``.
        pbar: Show a tqdm progress bar over the batches when true.

    Returns:
        A list with the decoded outputs of all batches, in input order.
        Empty *prompts* yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # NOTE(review): load_model builds the model with AutoModel; confirm that
    # the resulting class supports generate() for the chosen checkpoint
    # (a language-model head may be required).
    model, tokenizer = load_model(model)
    model = model.to(device)

    responses = []
    # Stride by batch_size so every prompt is covered exactly once. The
    # former "+ 0.99" rounding dropped the final partial batch whenever it
    # was under 1% of batch_size (e.g. 129 prompts at batch_size=128).
    for start in tqdm(range(0, len(prompts), batch_size), disable=not pbar):
        batch = prompts[start : start + batch_size]
        encoded_input = tokenizer(
            batch, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            model_output = model.generate(**encoded_input)
        responses.extend(
            tokenizer.batch_decode(model_output, skip_special_tokens=True)
        )
    return responses
